library(tidyverse)
library(brms)
library(tidybayes)
library(bayesplot)
## This is bayesplot version 1.7.0
## - Online documentation and vignettes at mc-stan.org/bayesplot
## - bayesplot theme set to bayesplot::theme_default()
##    * Does _not_ affect other ggplot2 plots
##    * See ?bayesplot_theme_set for details on theme setting
library(cowplot)
library(scales)
library(hexbin)
library(glue)

# Render non-interactively: never pause for input between plot pages.
devAskNewPage(ask = FALSE)

theme_set(theme_light())

# Project code, sourced into the global environment in this order: shared
# simulation helpers, the shared brms model definition, then this model's
# prior file (presumably where bprior_full is defined -- confirm).
walk(
  c(
    glue("{params$common_dir_str}/simulation.R"),
    glue("{params$common_dir_str}/brms_model.R"),
    glue("{params$model_dir_str}/model_prior.R")
  ),
  source
)

bprior_full %>% print()
##              prior class        coef group resp   dpar nlpar bound
## 1 normal(3.8, 0.5)     b   intercept            circSD            
## 2   normal(0, 0.5)     b stimulation            circSD            
## 3   normal(0, 0.5)    sd   Intercept  subj      circSD            
## 4   normal(0, 0.5)    sd stimulation  subj      circSD            
## 5   normal(0, 1.5)     b   intercept             theta            
## 6   normal(0, 1.5)     b stimulation             theta            
## 7     normal(0, 1)    sd   Intercept  subj       theta            
## 8     normal(0, 1)    sd stimulation  subj       theta

load (or simulate) data

# Stimulation codes used on both linear predictors: 0 = pre, 1 = post (the
# stimulation coefficient only contributes when the code is 1).
conditions <- c(0, 1)

sim_datasets_fpath <- glue("{params$save_dir_str}/sim_datasets.rds")

# Reuse the cached simulation when it exists; otherwise simulate it and cache
# it to sim_datasets_fpath so later renders are reproducible.
if (file.exists(sim_datasets_fpath)) {

  sim_datasets <- readRDS(sim_datasets_fpath)

} else {

  print("simulating")

  nsim_datasets <- 1

  # One row per simulated dataset: the group-level generating values and the
  # design size. The *_prior_mu / *_prior_sd values are not defined in this
  # file (presumably they come from the sourced model_prior.R -- confirm).
  sim_priors <- tibble(
    sim_num = 1:nsim_datasets,
    alpha0_mu =  alpha0_mu_prior_mu,
    alpha0_sigma = alpha0_sigma_prior_sd,
    alphaD_mu = alphaD_mu_prior_mu,
    alphaD_sigma = alphaD_sigma_prior_sd,
    beta0_mu = beta0_mu_prior_mu,
    beta0_sigma = beta0_sigma_prior_sd,
    betaD_mu = betaD_mu_prior_mu,
    betaD_sigma = betaD_sigma_prior_sd,
    nsubj = params$nsubj_sim,
    nobs_per_cond = params$nobs_per_cond_sim
  )

  sim_datasets <-
    sim_priors %>%
    mutate(
      # use draw_subj to sample nsubj_sim per sim using group-level parameter draws
      dataset = pmap(sim_priors, draw_subj),
      # trial design for each dataset: every condition repeated that row's
      # nobs_per_cond times (map() so each row uses its own value)
      stimulation = map(nobs_per_cond, ~ rep(conditions, each = .x))
    ) %>%

    # first unnest dataset, expanding by nsubj_sim and copying stimulation list to each subj
    unnest(dataset) %>%

    # then unnest stimulation, expanding by nobs_per_cond_sim*2
    unnest(stimulation) %>%

    # now use the likelihood to simulate observations
    mutate(
      # evaluate and delink linear model on pMem
      pMem = inv_logit(subj_beta0 + (subj_betaD * stimulation)),

      # evaluate and delink linear model on circSD (degrees) -> kappa
      k = sd2k_vec(
        pracma::deg2rad(
          exp(subj_alpha0 + (subj_alphaD * stimulation)))),

      # Bernoulli(pMem) draw per trial. This is exactly what the deprecated
      # purrr::rbernoulli() computed, so the RNG stream is unchanged.
      memFlip = runif(n()) > (1 - pMem),

      # von Mises draw per trial; presumably rvonmises_vec samples around pi,
      # so subtracting pi recentres on 0 -- confirm against its definition
      vm_draw = rvonmises_vec(1, pi, k) - pi,

      # draw from unif for each trial
      unif_draw = runif(n(), -pi, pi),

      # assign either vm_draw or unif_draw to each trial, depending on memFlip
      obs_radian = memFlip * vm_draw + (1 - memFlip) * unif_draw,

      # convert to degrees
      obs_degree = obs_radian * (180 / pi)
    ) %>%
    select(-c(pMem, k, memFlip, vm_draw, unif_draw)) %>%
    nest(subj_obs = c(stimulation, obs_degree, obs_radian)) %>%
    nest(dataset = c(subj, subj_alpha0, subj_alphaD, subj_beta0, subj_betaD, nobs_per_condition, subj_obs))

  saveRDS(sim_datasets, file = sim_datasets_fpath)

}

# Flatten the nested simulation to one row per trial.
# Columns kept:
#   subj        - as a factor, levels in order of appearance
#   stimulation - the condition code
#   obs_degree  - the response in degrees
#   error       - the response in radians (renamed for the model)
obs_only <-
  sim_datasets %>%
  unnest(dataset) %>%
  unnest(subj_obs) %>%
  select(subj, stimulation, obs_degree, error = obs_radian) %>%
  mutate(subj = as_factor(subj))

peek at the simulated data

# Histogram (10-degree bins, density scale) with rug and density overlay of
# the simulated responses in degrees, one facet row per subject, for the
# trials of a single stimulation condition.
plot_obs_density <- function(condition) {
  obs_only %>%
    filter(stimulation == condition) %>%
    ggplot(aes(x = obs_degree)) +
    geom_histogram(binwidth = 10, aes(y = ..density..)) +
    geom_rug() +
    geom_density(aes(y = ..density..)) +
    facet_wrap(vars(subj), ncol = 1)
}

plot_obs_density(0)

plot_obs_density(1)

fit brms

# MCMC settings passed to brm().
iter    <- 6000  # iterations per chain, warmup included
warmup  <- 3000  # warmup iterations per chain (discarded)
cores   <- 4
chains  <- 4
# post-warmup draws pooled over all chains (no thinning)
n_post_samples <- (iter - warmup) * chains

# Save the Stan program brms generates for this model/data/prior, so the
# model can be inspected without refitting.
make_stancode(bform_full, obs_only,
              family = vm_uniform_mix,
              prior = bprior_full,
              stanvars = stanvars) %>%
  writeLines(glue("{params$save_dir_str}/stan_code.txt"))

# Fit the von Mises / uniform mixture model.
# - sample_prior = "yes" also stores prior draws.
# - file = ... saves the fit as .rds; when that file already exists brm()
#   loads it instead of refitting.
# NOTE(review): newer brms releases renamed `inits` to `init`.
model_fit <- brm(
  bform_full, obs_only,
  family       = vm_uniform_mix,
  prior        = bprior_full,
  stanvars     = stanvars,
  sample_prior = "yes",
  iter    = iter,
  warmup  = warmup,
  chains  = chains,
  cores   = cores,
  control = list(adapt_delta = 0.99),
  inits   = 0,
  file    = glue("{params$save_dir_str}/sim_model_fit")
)

# Summary of the fit: population- and group-level estimates, Rhat and ESS.
model_fit %>% print()
## Warning: There were 9 divergent transitions after warmup. Increasing adapt_delta above 0.99 may help.
## See http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
##  Family: vm_uniform_mix 
##   Links: mu = identity; circSD = log; theta = logit; a = identity; b = identity 
## Formula: error ~ 0 
##          circSD ~ 0 + intercept + stimulation + (1 + stimulation || subj)
##          theta ~ 0 + intercept + stimulation + (1 + stimulation || subj)
##          a = -3.14
##          b = 3.14
##    Data: obs_only (Number of observations: 756) 
## Samples: 4 chains, each with iter = 6000; warmup = 3000; thin = 1;
##          total post-warmup samples = 12000
## 
## Group-Level Effects: 
## ~subj (Number of levels: 3) 
##                        Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
## sd(circSD_Intercept)       0.66      0.23     0.31     1.20 1.00     8102
## sd(circSD_stimulation)     0.33      0.24     0.02     0.91 1.00     6067
## sd(theta_Intercept)        1.00      0.44     0.33     2.04 1.00     6588
## sd(theta_stimulation)      0.58      0.46     0.02     1.74 1.00     6597
##                        Tail_ESS
## sd(circSD_Intercept)       7553
## sd(circSD_stimulation)     5548
## sd(theta_Intercept)        4911
## sd(theta_stimulation)      5040
## 
## Population-Level Effects: 
##                    Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
## circSD_intercept       3.49      0.33     2.89     4.18 1.00     6341
## circSD_stimulation     0.19      0.25    -0.29     0.71 1.00     6338
## theta_intercept       -0.00      0.63    -1.29     1.24 1.00     6369
## theta_stimulation     -0.16      0.54    -1.28     0.91 1.00     7168
##                    Tail_ESS
## circSD_intercept       7491
## circSD_stimulation     6810
## theta_intercept        6969
## theta_stimulation      6526
## 
## Family Specific Parameters: 
##                      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
## b_circSD_intercept       3.49      0.33     2.89     4.18 1.00     6341
## b_circSD_stimulation     0.19      0.25    -0.29     0.71 1.00     6338
## b_theta_intercept       -0.00      0.63    -1.29     1.24 1.00     6369
## b_theta_stimulation     -0.16      0.54    -1.28     0.91 1.00     7168
##                      Tail_ESS
## b_circSD_intercept       7491
## b_circSD_stimulation     6810
## b_theta_intercept        6969
## b_theta_stimulation      6526
## 
## Samples were drawn using sampling(NUTS). For each parameter, Eff.Sample 
## is a crude measure of effective sample size, and Rhat is the potential 
## scale reduction factor on split chains (at convergence, Rhat = 1).
# One row per simulated subject: generating hyperparameters alongside the
# subject-level parameter draws.
glimpse(unnest(sim_datasets, dataset))
## Observations: 3
## Variables: 18
## $ sim_num            <int> 1, 1, 1
## $ alpha0_mu          <dbl> 3.8, 3.8, 3.8
## $ alpha0_sigma       <dbl> 0.5, 0.5, 0.5
## $ alphaD_mu          <dbl> 0, 0, 0
## $ alphaD_sigma       <dbl> 0.5, 0.5, 0.5
## $ beta0_mu           <dbl> 0, 0, 0
## $ beta0_sigma        <dbl> 1, 1, 1
## $ betaD_mu           <dbl> 0, 0, 0
## $ betaD_sigma        <dbl> 1, 1, 1
## $ nsubj              <int> 3, 3, 3
## $ nobs_per_cond      <int> 126, 126, 126
## $ subj               <int> 1, 2, 3
## $ subj_alpha0        <dbl> 4.336359, 3.887744, 2.771223
## $ subj_alphaD        <dbl> -0.3885006, 0.6059154, -0.1587515
## $ subj_beta0         <dbl> -0.7107940, -0.9290677, 1.1493721
## $ subj_betaD         <dbl> -0.6139408, -0.1176994, -0.2847914
## $ nobs_per_condition <int> 126, 126, 126
## $ subj_obs           <S3: vctrs_list_of> 0.00000000, 0.00000000, 0.00000…

fit checks

divergences

# Sampler diagnostics:
#   np       - NUTS sampler parameters (source of the divergence count)
#   rhat     - R-hat per parameter
#   neff_rat - effective-sample-size ratio per parameter
np <- nuts_params(model_fit)
rhat <- brms::rhat(model_fit)
neff_rat <- neff_ratio(model_fit)

# Total divergent transitions: sum of the divergent__ indicator over all rows.
summarise(filter(np, Parameter == "divergent__"), n_div = sum(Value))
##   n_div
## 1     9

print(rhat)

# R-hat plot. mcmc_rhat() already sets an x scale, so this one replaces it
# (hence ggplot2's "Scale for 'x' is already present" message).
mcmc_rhat(rhat) +
  yaxis_text(hjust = 1) +
  scale_x_continuous(breaks = pretty_breaks(6))
## Scale for 'x' is already present. Adding another scale for 'x', which
## will replace the existing scale.

neff ratio

# Effective-sample-size ratios, with parameter names shown on the y axis.
mcmc_neff(neff_rat) +
  yaxis_text(hjust = 1)

trace plots

# Trace plots built from the underlying Stan fit converted to an array.
model_fit$fit %>%
  as.array() %>%
  mcmc_trace()

other diagnostics (default brms plots)

# brms' default plots for the fit, drawn without prompting between pages.
model_fit %>% plot(ask = FALSE)

plot posteriors

arrange posterior samples

# compute summaries for plot

# Group-level quantities on the response scale: one row per posterior draw
# and quantity (param = "<circSD|pMem>_<pre|post|ES>"), with two values:
#   mean - population-level mean for the condition (fixed effects only)
#   pred - predictive value for a new subject, whose intercept and
#          stimulation offsets are drawn from the fitted group-level sds
group_level_samples <-
  spread_draws(model_fit, `(b|sd)_.*`, regex = TRUE) %>%
  mutate(
         # New-subject coefficients on the linear-predictor scale, drawn ONCE
         # per posterior draw. Pre and post share them, so *_ES_pred is a
         # within-subject change, matching the subject-level ES estimates.
         # Drawing the intercept separately for pre and post would add
         # between-subject intercept variance to ES.
         new_circSD_int  = rnorm(n(), b_circSD_intercept, sd_subj__circSD_Intercept),
         new_circSD_stim = rnorm(n(), b_circSD_stimulation, sd_subj__circSD_stimulation),
         new_theta_int   = rnorm(n(), b_theta_intercept, sd_subj__theta_Intercept),
         new_theta_stim  = rnorm(n(), b_theta_stimulation, sd_subj__theta_stimulation),
         # group level parameters (circSD: log link; theta/pMem: logit link)
         circSD_pre_mean  = exp(b_circSD_intercept),
         circSD_post_mean = exp(b_circSD_intercept + b_circSD_stimulation),
         circSD_ES_mean   = circSD_post_mean - circSD_pre_mean,
         pMem_pre_mean    = inv_logit(b_theta_intercept),
         pMem_post_mean   = inv_logit(b_theta_intercept + b_theta_stimulation),
         pMem_ES_mean     = pMem_post_mean - pMem_pre_mean,
         # predictive dist for a new subject
         circSD_pre_pred  = exp(new_circSD_int),
         circSD_post_pred = exp(new_circSD_int + new_circSD_stim),
         circSD_ES_pred   = circSD_post_pred - circSD_pre_pred,
         pMem_pre_pred    = inv_logit(new_theta_int),
         pMem_post_pred   = inv_logit(new_theta_int + new_theta_stim),
         pMem_ES_pred     = pMem_post_pred - pMem_pre_pred
         ) %>%
  # keep only the draw ids (.chain/.iteration/.draw) and derived quantities
  select(-contains("b_"), -contains("sd_subj"), -starts_with("new_")) %>%
  # "circSD_pre_mean" -> param = "circSD_pre", stat = "mean" (greedy first group)
  pivot_longer(-contains("."), names_to = c("param", "stat"), names_pattern = "(.*)_(.*)", values_to = "value") %>%
  pivot_wider(names_from = stat, values_from = value)
  
# Median and 50/80/95% quantile intervals of every group-level quantity,
# summarising both the `mean` and `pred` columns per param.
group_level_summary <- median_qi(
  group_by(group_level_samples, param),
  .width = c(0.5, 0.8, 0.95)
)



# Per-subject circSD draws on the response scale.
# Each row combines one posterior draw and one subject: the population-level
# circSD coefficients plus that subject's offsets
# (r_subj__circSD[subj, term], with term = Intercept / stimulation),
# exponentiated because circSD uses a log link.
circSD_subj_samples <-
  model_fit %>%
  spread_draws(b_circSD_intercept, b_circSD_stimulation, r_subj__circSD[subj, term]) %>%
  ungroup() %>%
  pivot_wider(
    names_from   = term,
    values_from  = r_subj__circSD,
    names_prefix = "offset_"
  ) %>%
  mutate(
    circSD_pre  = exp(b_circSD_intercept + offset_Intercept),
    circSD_post = exp(b_circSD_intercept + offset_Intercept +
                      b_circSD_stimulation + offset_stimulation),
    circSD_ES   = circSD_post - circSD_pre
  ) %>%
  select(-b_circSD_intercept, -offset_Intercept,
         -b_circSD_stimulation, -offset_stimulation) %>%
  pivot_longer(contains("circSD"), names_to = "param", values_to = "value")

# Per-subject posterior median with 90% and 95% quantile intervals for each
# circSD quantity. This gives one row per subj x param x interval width.
circSD_subj_summary <- median_qi(
  group_by(circSD_subj_samples, subj, param),
  .width = c(0.90, 0.95)
)
## Warning: unnest() has a new interface. See ?unnest for details.
## Try `df %>% unnest(c(.lower, .upper))`, with `mutate()` if needed

## Warning: unnest() has a new interface. See ?unnest for details.
## Try `df %>% unnest(c(.lower, .upper))`, with `mutate()` if needed
# Per-subject pMem draws on the probability scale.
# Each row combines one posterior draw and one subject: the population-level
# theta coefficients plus that subject's offsets
# (r_subj__theta[subj, term], with term = Intercept / stimulation),
# passed through inv_logit because theta uses a logit link.
pMem_subj_samples <-
  model_fit %>%
  spread_draws(b_theta_intercept, b_theta_stimulation, r_subj__theta[subj, term]) %>%
  ungroup() %>%
  pivot_wider(
    names_from   = term,
    values_from  = r_subj__theta,
    names_prefix = "offset_"
  ) %>%
  mutate(
    pMem_pre  = inv_logit(b_theta_intercept + offset_Intercept),
    pMem_post = inv_logit(b_theta_intercept + offset_Intercept +
                          b_theta_stimulation + offset_stimulation),
    pMem_ES   = pMem_post - pMem_pre
  ) %>%
  select(-b_theta_intercept, -offset_Intercept,
         -b_theta_stimulation, -offset_stimulation) %>%
  pivot_longer(contains("pMem"), names_to = "param", values_to = "value")

# Per-subject posterior median with 90% and 95% quantile intervals for each
# pMem quantity. This gives one row per subj x param x interval width.
pMem_subj_summary <- median_qi(
  group_by(pMem_subj_samples, subj, param),
  .width = c(0.90, 0.95)
)
## Warning: unnest() has a new interface. See ?unnest for details.
## Try `df %>% unnest(c(.lower, .upper))`, with `mutate()` if needed

## Warning: unnest() has a new interface. See ?unnest for details.
## Try `df %>% unnest(c(.lower, .upper))`, with `mutate()` if needed
# Print the 95% interval of each group-level mean. The predictive column is
# dropped so that only `mean` is summarised.
median_qi(
  group_by(select(group_level_samples, -pred), param),
  .width = 0.95
)
## Warning: unnest() has a new interface. See ?unnest for details.
## Try `df %>% unnest(c(.lower, .upper))`, with `mutate()` if needed
## # A tibble: 6 x 7
##   param           mean  .lower .upper .width .point .interval
##   <chr>          <dbl>   <dbl>  <dbl>  <dbl> <chr>  <chr>    
## 1 circSD_ES     6.09    -8.84  36.7     0.95 median qi       
## 2 circSD_post  39.3     18.5   89.2     0.95 median qi       
## 3 circSD_pre   32.2     18.0   65.3     0.95 median qi       
## 4 pMem_ES      -0.0330  -0.277  0.201   0.95 median qi       
## 5 pMem_post     0.458    0.147  0.813   0.95 median qi       
## 6 pMem_pre      0.499    0.216  0.776   0.95 median qi

group level posteriors

# circSD summary plot (pre, post, ES):
#   - half-eye, nudged up: posterior of the group-level mean (90/95%)
#   - interval bands: predictive distribution for a new subject (50/80/95%)
#   - points: per-subject posterior medians
circSD_p1 <-
  ggplot(data = filter(group_level_samples, str_detect(param, "circSD"))) +
  geom_halfeyeh(mapping = aes(x = mean, y = param),
                .width = c(.90, .95),
                position = position_nudge(y = 0.15)) +
  stat_intervalh(mapping = aes(x = pred, y = param),
                 .width = c(.5, .8, .95)) +
  geom_point(data = circSD_subj_summary,
             mapping = aes(x = value, y = param),
             size = 2) +
  scale_color_brewer() +
  scale_x_continuous(breaks = pretty_breaks(10)) +
  coord_cartesian(xlim = c(-50, 150)) +
  labs(subtitle = "circSD: group level mean posterior (median, 90%, 95% interval), \nsubject posterior medians, \ncondition predictive dist of subjects",
       x = "circSD",
       color = "interval")

# group level pMem pre, post and ES plot
#   - half-eye, nudged up: posterior of the group-level mean (90/95%)
#   - interval bands: predictive distribution for a new subject (50/80/95%)
#   - points: per-subject posterior medians
# There is no fixed x range here (unlike circSD): pMem lies in [0, 1] and its
# ES in [-1, 1].
pMem_p1 <-
  ggplot(data = filter(group_level_samples, str_detect(param, "pMem"))) +
  geom_halfeyeh(mapping = aes(x = mean, y = param),
                .width = c(.90, .95),
                position = position_nudge(y = 0.15)) +
  stat_intervalh(mapping = aes(x = pred, y = param),
                 .width = c(.5, .8, .95)) +
  geom_point(data = pMem_subj_summary,
             mapping = aes(x = value, y = param),
             size = 2) +
  scale_color_brewer() +
  scale_x_continuous(breaks = pretty_breaks(10)) +
  labs(subtitle = "pMem: group level mean posterior (median, 90%, 95% interval), \nsubject posterior medians, \ncondition predictive dist of subjects",
       x = "pMem",
       color = "interval")

plot_grid(circSD_p1, pMem_p1, align = "hv", ncol = 1)

circSD pre, post, ES

# circSD pre, post and ES: the group-level mean posterior and each subject's
# posterior are drawn as half-eyes. The new-subject predictive distribution
# and the subject posterior medians are nudged just below the group row.

# Build one per-quantity comparison plot.
#   group_samples: group-level draws with columns param, mean, pred
#   subj_samples:  per-subject draws with columns subj, param, value
#   subj_summary:  per-subject median_qi summary with columns param, value
#   param_name:    quantity to plot, e.g. "circSD_pre"; group and subject rows
#                  are selected by substring match, summary rows by equality
#   x_label:       x-axis title
#   pred_label:    word naming the predictive band in the subtitle
plot_condition_posteriors <- function(group_samples, subj_samples, subj_summary,
                                      param_name, x_label, pred_label) {
  group_draws <- group_samples %>%
    filter(str_detect(param, param_name))

  # group mean posterior rows plus one labelled row set per subject
  # ("<param_name>_<subj>")
  halfeye_data <- rbind(
    group_draws %>%
      select(-pred, value = mean),
    subj_samples %>%
      filter(str_detect(param, param_name)) %>%
      unite(param, param, subj)
  )

  ggplot() +
    geom_halfeyeh(data = halfeye_data,
                  aes(y = param, x = value),
                  .width = c(.90, .95)) +
    # predictive distribution for a new subject, below the group half-eye
    stat_intervalh(data = group_draws,
                   aes(y = param, x = pred),
                   .width = c(.5, .8, .95),
                   position = position_nudge(y = -0.15)) +
    # subject posterior medians shown inside the prediction band
    geom_point(data = subj_summary %>% filter(param == param_name),
               aes(y = param, x = value),
               size = 2,
               position = position_nudge(y = -0.15)) +
    scale_color_brewer() +
    scale_x_continuous(breaks = pretty_breaks(10)) +
    labs(subtitle = paste0(param_name,
                           ": group level mean posterior (median, 90%, 95% interval), \nsubject posterior, \n",
                           pred_label,
                           " predictive dist of subjects"),
         x = x_label,
         color = "interval")
}

circSD_p2 <- plot_condition_posteriors(
  group_level_samples, circSD_subj_samples, circSD_subj_summary,
  param_name = "circSD_pre", x_label = "circSD", pred_label = "condition"
)

circSD_p3 <- plot_condition_posteriors(
  group_level_samples, circSD_subj_samples, circSD_subj_summary,
  param_name = "circSD_post", x_label = "circSD", pred_label = "condition"
)

circSD_p4 <- plot_condition_posteriors(
  group_level_samples, circSD_subj_samples, circSD_subj_summary,
  param_name = "circSD_ES", x_label = "delta circSD", pred_label = "ES"
)

plot_grid(circSD_p2, circSD_p3, circSD_p4, ncol = 1, align = "hv")

pMem pre, post, ES

# pMem pre, post and ES: the group-level mean posterior and each subject's
# posterior are drawn as half-eyes. The new-subject predictive distribution
# and the subject posterior medians are nudged just below the group row.

# Build one per-quantity comparison plot.
#   group_samples: group-level draws with columns param, mean, pred
#   subj_samples:  per-subject draws with columns subj, param, value
#   subj_summary:  per-subject median_qi summary with columns param, value
#   param_name:    quantity to plot, e.g. "pMem_pre"; group and subject rows
#                  are selected by substring match, summary rows by equality
#   x_label:       x-axis title
#   pred_label:    word naming the predictive band in the subtitle
plot_condition_posteriors <- function(group_samples, subj_samples, subj_summary,
                                      param_name, x_label, pred_label) {
  group_draws <- group_samples %>%
    filter(str_detect(param, param_name))

  # group mean posterior rows plus one labelled row set per subject
  # ("<param_name>_<subj>")
  halfeye_data <- rbind(
    group_draws %>%
      select(-pred, value = mean),
    subj_samples %>%
      filter(str_detect(param, param_name)) %>%
      unite(param, param, subj)
  )

  ggplot() +
    geom_halfeyeh(data = halfeye_data,
                  aes(y = param, x = value),
                  .width = c(.90, .95)) +
    # predictive distribution for a new subject, below the group half-eye
    stat_intervalh(data = group_draws,
                   aes(y = param, x = pred),
                   .width = c(.5, .8, .95),
                   position = position_nudge(y = -0.15)) +
    # subject posterior medians shown inside the prediction band
    geom_point(data = subj_summary %>% filter(param == param_name),
               aes(y = param, x = value),
               size = 2,
               position = position_nudge(y = -0.15)) +
    scale_color_brewer() +
    scale_x_continuous(breaks = pretty_breaks(10)) +
    labs(subtitle = paste0(param_name,
                           ": group level mean posterior (median, 90%, 95% interval), \nsubject posterior, \n",
                           pred_label,
                           " predictive dist of subjects"),
         x = x_label,
         color = "interval")
}

pMem_p2 <- plot_condition_posteriors(
  group_level_samples, pMem_subj_samples, pMem_subj_summary,
  param_name = "pMem_pre", x_label = "pMem", pred_label = "condition"
)

pMem_p3 <- plot_condition_posteriors(
  group_level_samples, pMem_subj_samples, pMem_subj_summary,
  param_name = "pMem_post", x_label = "pMem", pred_label = "condition"
)

pMem_p4 <- plot_condition_posteriors(
  group_level_samples, pMem_subj_samples, pMem_subj_summary,
  param_name = "pMem_ES", x_label = "delta pMem", pred_label = "ES"
)

plot_grid(pMem_p2, pMem_p3, pMem_p4, ncol = 1, align = "hv")

group level joint posteriors

# Pairs plot of the group-level mean draws, one column per derived quantity.
# Draws are passed as a plain table with no chain structure, which is why
# bayesplot warns about having only one chain.
group_mean_wide <- group_level_samples %>%
  pivot_wider(id_cols = contains("."), names_from = param, values_from = mean)

mcmc_pairs(select(group_mean_wide, -contains(".")), off_diag_fun = "hex")
## Warning: Only one chain in 'x'. This plot is more useful with multiple
## chains.

parameter recovery: group-level posteriors vs. true simulated values

# True group-level generating values of each derived quantity, in long format
# (param, value). They are computed from the simulation hyperparameters with
# the same transforms as the posterior summaries (exp for circSD, inv_logit
# for pMem), so the two can be compared directly.
# inv_logit replaces the hand-written exp(x) / (exp(x) + 1), which returns
# NaN once exp(x) overflows, and matches the rest of the script.
# NOTE(review): despite the name these are the true simulated values, not a
# prior summary; this object also shadows brms::prior_summary as a variable.
prior_summary <-
  sim_datasets %>%
  transmute(circSD_pre  = exp(alpha0_mu),
            circSD_post = exp(alpha0_mu + alphaD_mu),
            circSD_ES   = circSD_post - circSD_pre,
            pMem_pre    = inv_logit(beta0_mu),
            pMem_post   = inv_logit(beta0_mu + betaD_mu),
            pMem_ES     = pMem_post - pMem_pre) %>%
  pivot_longer(everything(), names_to = "param", values_to = "value")

# circSD recovery: posterior of each group-level circSD quantity, with the
# true generating value from prior_summary overlaid.
circSD_p1 <- group_level_samples %>%
  filter(str_detect(param, "circSD")) %>%
  ggplot() +
  # posterior dist + interval for group mean
  geom_halfeyeh(aes(y = param, x = mean), .width = c(.90, .95), position = position_nudge(y = 0.15)) +
  # True generating value. The colour is SET outside aes(), not mapped;
  # mapping the constant "red" would use the default palette colour and add
  # a spurious "colour: red" legend.
  geom_point(data = prior_summary %>% filter(str_detect(param, "circSD")),
             aes(y = param, x = value),
             color = "red", size = 2) +
  # decorations
  scale_x_continuous(breaks = pretty_breaks(10)) +
  coord_cartesian(xlim = c(-50, 150)) +
  labs(subtitle = "circSD: group level mean posterior (median, 90%, 95% interval), \ntrue param value = red dot",
       x = "circSD")

# group level pMem pre, post and ES plot

# pMem recovery: posterior of each group-level pMem quantity, with the true
# generating value from prior_summary overlaid.
pMem_p1 <- group_level_samples %>%
  filter(str_detect(param, "pMem")) %>%
  ggplot() +
  # posterior dist + interval for group mean
  geom_halfeyeh(aes(y = param, x = mean), .width = c(.90, .95), position = position_nudge(y = 0.15)) +
  # True generating value. The colour is SET outside aes(), not mapped;
  # mapping the constant "red" would use the default palette colour and add
  # a spurious "colour: red" legend.
  geom_point(data = prior_summary %>% filter(str_detect(param, "pMem")),
             aes(y = param, x = value),
             color = "red", size = 2) +
  # decorations
  scale_x_continuous(breaks = pretty_breaks(10)) +
  labs(subtitle = "pMem: group level mean posterior (median, 90%, 95% interval), \ntrue param value = red dot",
       x = "pMem")

plot_grid(circSD_p1, pMem_p1, align = "hv", ncol = 1)